{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensorflow.keras.layers import Dense, Flatten \n",
    "from tensorflow.keras import Model\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Download a dataset\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "# Batch and shuffle the data\n",
    "train_ds = tf.data.Dataset.from_tensor_slices(\n",
    "    (x_train.astype('float32') / 255, y_train)).shuffle(1024).batch(32)\n",
    "\n",
    "test_ds = tf.data.Dataset.from_tensor_slices(\n",
    "    (x_test.astype('float32') / 255, y_test)).batch(32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VFLPassiveModel(Model):\n",
    "    def __init__(self):\n",
    "        super(VFLPassiveModel, self).__init__()\n",
    "        self.flatten = Flatten()\n",
    "        self.d1 = Dense(10, name=\"dense1\")\n",
    "\n",
    "    def call(self, x):\n",
    "        x = self.flatten(x)\n",
    "        return self.d1(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VFLActiveModel(Model):\n",
    "    def __init__(self):\n",
    "        super(VFLActiveModel, self).__init__()\n",
    "        self.added = tf.keras.layers.Add()\n",
    "\n",
    "    def call(self, x):\n",
    "        x = self.added(x)\n",
    "        return tf.keras.layers.Softmax()(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 0.396317720413208, Accuracy: 88.98332977294922, Test Loss: 0.2956085503101349, Test Accuracy: 91.50999450683594\n",
      "Epoch 2, Loss: 0.289332777261734, Accuracy: 91.95999908447266, Test Loss: 0.28157228231430054, Test Accuracy: 91.93999481201172\n",
      "Epoch 3, Loss: 0.27582597732543945, Accuracy: 92.39500427246094, Test Loss: 0.28288891911506653, Test Accuracy: 92.04000091552734\n",
      "Epoch 4, Loss: 0.2672642767429352, Accuracy: 92.70500183105469, Test Loss: 0.2774609327316284, Test Accuracy: 92.06999969482422\n",
      "Epoch 5, Loss: 0.26180967688560486, Accuracy: 92.7933349609375, Test Loss: 0.2739064395427704, Test Accuracy: 92.29000091552734\n"
     ]
    }
   ],
   "source": [
    "passive_model = VFLPassiveModel()\n",
    "active_model = VFLActiveModel()\n",
    "\n",
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy()\n",
    "optimizer = tf.keras.optimizers.Adam()\n",
    "\n",
    "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
    "train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')\n",
    "\n",
    "test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')\n",
    "\n",
    "EPOCHS = 5\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    # For each batch of images and labels\n",
    "    for images, labels in train_ds:\n",
    "        with tf.GradientTape() as passive_tape:\n",
    "            # passive_model sends passive_output to active_model\n",
    "            passive_output = passive_model(images)\n",
    "            with tf.GradientTape() as active_tape:\n",
    "                active_tape.watch(passive_output)\n",
    "                active_output = active_model([passive_output, passive_output])\n",
    "                loss = loss_object(labels, active_output)\n",
    "            # active_model sends passive_output_gradients back to passive_model\n",
    "            passive_output_gradients = active_tape.gradient(loss, passive_output)\n",
    "            #print(passive_output_gradients)\n",
    "            passive_output_loss = tf.multiply(passive_output, passive_output_gradients.numpy())\n",
    "        passive_weight_gradients = passive_tape.gradient(passive_output_loss, passive_model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(passive_weight_gradients, passive_model.trainable_variables))\n",
    "\n",
    "        train_loss(loss)\n",
    "        train_accuracy(labels, active_output)\n",
    "\n",
    "    for test_images, test_labels in test_ds:\n",
    "        passive_output = passive_model(test_images)\n",
    "        active_output = active_model([passive_output, passive_output])\n",
    "        t_loss = loss_object(test_labels, active_output)\n",
    "\n",
    "        test_loss(t_loss)\n",
    "        test_accuracy(test_labels, active_output)\n",
    "\n",
    "    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'\n",
    "    print(template.format(epoch+1,\n",
    "                        train_loss.result(),\n",
    "                        train_accuracy.result()*100,\n",
    "                        test_loss.result(),\n",
    "                        test_accuracy.result()*100))\n",
    "\n",
    "    # Reset the metrics for the next epoch\n",
    "    train_loss.reset_states()\n",
    "    train_accuracy.reset_states()\n",
    "    test_loss.reset_states()\n",
    "    test_accuracy.reset_states()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
